Set shape basis method #276
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           development     #276    +/-   ##
=============================================
  Coverage       96.13%    96.13%
=============================================
  Files              34        34
  Lines            2507      2642    +135
=============================================
+ Hits             2410      2540    +130
- Misses             97       102      +5
```

View full report in Codecov by Sentry.
In the next PR I'll improve the API of the transformer and move the `TransformerBasis` tests into a dedicated script. If you want an overview of how to work with the `TransformerBasis` before digging into the code, check out https://nemos--276.org.readthedocs.build/en/276/how_to_guide/plot_05_transformer_basis.html
Looking good! I just have a couple of suggested changes to clarify, in documentation/error messages, where `set_input_shape` will have an impact. Also, some code in a Warning admonition in plot_06_sklearn_pipeline_cv_demo needs to be updated to the new syntax.
Co-authored-by: Sarah Jo Venditto <[email protected]>
Haven't had a chance to look at tests, but here's my first go-through. Some tweaking in the tutorials, and also:
- I think we can probably move away from using numbers in the names of the tutorials. The only reason to do so is for the automatic sorting to work, but we're no longer making use of automatic sorting, right?
- I am confused about what lives in the `Basis` superclass and what lives in `AtomicBasisMixin`. For example, why does anything to do with `n_basis_input` or `n_basis_funcs` live in `Basis`?
src/nemos/basis/_basis_mixin.py (Outdated)

```python
from ._basis import Basis


def set_input_shape_state(method):
```
why not make this more general and accept a list of attributes to copy, which default to `["_n_basis_input_", "_input_shape_"]`?
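A sketch of that suggestion (the decorator-factory shape and the `DemoBasis` class are illustrative, not the actual nemos implementation):

```python
import functools


def set_input_shape_state(attributes=("_n_basis_input_", "_input_shape_")):
    """Decorator factory: after the wrapped method returns a new object
    (e.g. a clone), copy the listed attributes from ``self`` onto it."""

    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            new = method(self, *args, **kwargs)
            for attr in attributes:
                if hasattr(self, attr):
                    setattr(new, attr, getattr(self, attr))
            return new

        return wrapper

    return decorator


class DemoBasis:
    """Hypothetical class showing how the decorator would be applied."""

    def __init__(self):
        self._n_basis_input_ = (1,)
        self._input_shape_ = (100,)

    @set_input_shape_state()  # copies the default attribute list
    def clone(self):
        return DemoBasis.__new__(DemoBasis)  # bare instance, no state yet
```

Making the attribute list a parameter means the same decorator can serve any method that produces a fresh instance, not just one hard-coded pair of state attributes.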
src/nemos/basis/_basis_mixin.py (Outdated)

```python
@property
def basis1(self):
    return self._basis1


@basis1.setter
def basis1(self, bas: Basis):
    self._basis1 = bas
```
why are these properties if we do nothing special in the getter or setter? why not just use `self.basis1` instead of `self._basis1`?
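To make the trade-off concrete, a minimal comparison (class names hypothetical): a property with a pass-through getter and setter behaves identically to a plain attribute, so a property only earns its keep once the setter validates or updates related state.

```python
class WithProperty:
    """Pass-through property: no added behavior over a plain attribute."""

    def __init__(self, basis):
        self._basis1 = basis

    @property
    def basis1(self):
        return self._basis1

    @basis1.setter
    def basis1(self, bas):
        self._basis1 = bas  # nothing special happens here


class WithAttribute:
    """Same observable behavior with less code."""

    def __init__(self, basis):
        self.basis1 = basis
```

Both classes read and write `basis1` identically from the caller's perspective; the property version is only justified if the setter later grows validation logic.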
Add a `set_input_shape` method that initializes the basis. The method can accept a list of:

It prepares the basis and precomputes internal quantities that are useful for subsequent functionality, such as knowing how to split the feature axis, i.e. `split_by_features`. In the next PR, `TransformerBasis` will use the pre-computed shapes to split the concatenated inputs before processing. Additionally, it computes and stores the number of output features, information a user may want after building a complex composite basis.
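As a rough illustration of the mechanism (a toy sketch: `ToyBasis`, `ToyAdditive`, and their internals are invented here; only `set_input_shape` and the idea of precomputed split points come from this PR):

```python
from math import prod


class ToyBasis:
    """Toy atomic basis: n_basis_funcs outputs per flattened input element."""

    def __init__(self, n_basis_funcs):
        self.n_basis_funcs = n_basis_funcs
        self._input_shape_ = None

    def set_input_shape(self, shape):
        # store the expected (per-sample) input shape as state
        self._input_shape_ = tuple(shape)
        return self

    @property
    def n_output_features(self):
        # known as soon as the input shape is set, before seeing any data
        return prod(self._input_shape_) * self.n_basis_funcs


class ToyAdditive:
    """Toy composite basis that concatenates the features of two bases."""

    def __init__(self, basis1, basis2):
        self.basis1, self.basis2 = basis1, basis2

    @property
    def n_output_features(self):
        return self.basis1.n_output_features + self.basis2.n_output_features

    def split_slices(self):
        # precomputed shapes tell us where the concatenated feature axis splits
        n1 = self.basis1.n_output_features
        return slice(0, n1), slice(n1, self.n_output_features)


add = ToyAdditive(ToyBasis(3).set_input_shape((2,)), ToyBasis(4).set_input_shape((1,)))
```

Here `add.n_output_features` is 2·3 + 1·4 = 10, and `split_slices()` recovers the per-basis feature blocks without evaluating the basis on any data.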
This PR follows #275
EDIT:
PR Summary
In this PR, I refined the class structure to better separate basis attributes and methods, delegating their validation logic to mixin classes wherever possible.
New Classes
CompositeBasisMixin
:This mixin is inherited by additive and multiplicative bases. It implements methods for traversing the composite basis tree, such as
__sklearn_clone__
andsetup_basis
.AtomicBasisMixin
:Designed for non-composite (atomic) bases, this mixin stores the
n_basis_funcs
parameter and implements selected methods like__sklearn_clone__
, which have uniform implementations across all atomic bases.New Abstract Methods
- `set_input_shape`: stores the expected input shape, a state attribute required for compatibility with transformers. Parameters set by this method are carried over during cloning, such as in cross-validation. Concrete implementations are provided in `AtomicBasisMixin` and `CompositeBasisMixin`.
- `_set_input_independent_states`: responsible for setting all state variables that depend on class parameters (provided at initialization and retrievable with `get_params`).
- `setup_basis`: computes all state variables, both input-dependent (e.g., input shape) and input-independent (e.g., kernels for convolutional bases). Concrete implementations are found in the `Eval`, `Conv`, and `Composite` basis mixins.

Clone Method for Bases
- `__sklearn_clone__`: implements cloning logic to retain input-dependent states (such as input shape), which would otherwise be discarded by `sklearn.base.clone`. This is implemented in both `CompositeBasisMixin` and `AtomicBasisMixin`.

Modified/Moved Attributes and Methods
- Kernel-related logic: the `kernel_` attribute, along with the `_check_has_kernel` method and `set_kernel`, has been moved to the `Conv` mixin.
- Input shape validation: the `_check_input_shape_consistency` method has been relocated to `AtomicBasisMixin` and `CompositeBasisMixin`.
- Composite basis setters: setters for `basis1` and `basis2` in composite bases now support cross-validation scenarios.

Inheritance of New Mixins
- `CompositeBasisMixin`: inherited by `AdditiveBasis` and `MultiplicativeBasis`.
- `AtomicBasisMixin`: inherited by `SplineBasis` (the superclass for all splines), `RaisedCosineLinear` (the superclass for all raised cosines), and `OrthExponentialBasis` (the superclass for orthogonal exponential bases).
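The clone behavior described above can be sketched in isolation (a simplified stand-in, not the nemos code; since scikit-learn 1.3, `sklearn.base.clone` dispatches to a `__sklearn_clone__` method when one is defined):

```python
class ShapeAwareBasis:
    """Toy basis whose clone keeps input-dependent state."""

    def __init__(self, n_basis_funcs=5):
        self.n_basis_funcs = n_basis_funcs  # init parameter
        self._input_shape_ = None  # input-dependent state, not in get_params

    def get_params(self):
        return {"n_basis_funcs": self.n_basis_funcs}

    def set_input_shape(self, shape):
        self._input_shape_ = tuple(shape)
        return self

    def __sklearn_clone__(self):
        # a parameter-only clone rebuilds from get_params() and would drop
        # _input_shape_; restore it so cross-validation clones stay configured
        new = type(self)(**self.get_params())
        new._input_shape_ = self._input_shape_
        return new
```

Without the override, every fold of a grid search would receive a basis that has forgotten its configured input shape and would need `set_input_shape` called again.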